{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 根据kaggle，写的一个无监督的小Test，方便刚入手voxelmoroph的同学理清程序逻辑\n",
    "# 环境为tensorflow2.x\n",
    "# 源码根据voxelmorph在Github上的做修改，只保留运行这个小Test所需要的所有代码，基本只做了删没有改。\n",
    "# 安装tensorflow2，修改路径后，应该可以直接跑起来并看到结果\n",
    "\n",
    "import os, sys\n",
    "\n",
    "# 引入第三方库\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "# 本地引入  自行修改路径\n",
    "sys.path.append('/home/xxx/xxx/tf2-vm-cvpr/')\n",
    "import voxelmorph as vxm\n",
    "\n",
    "# 读取数据  使用kaggle上提供的mri-2d数据\n",
    "core_path = '/home/xxx/xxx/mri-2d/'\n",
    "x_train = np.load(os.path.join(core_path, 'train_vols.npy'))\n",
    "x_val = np.load(os.path.join(core_path, 'validate_vols.npy'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ndims = 2  # 数据是二维\n",
    "vol_shape = x_train.shape[1:]  # 输入数据shape\n",
    "nb_enc_features = [32, 32, 32, 32]  # unet 通道数\n",
    "nb_dec_features = [32, 32, 32, 32, 32, 16]\n",
    "\n",
    "# unet层\n",
    "unet = vxm.networks.unet_core(vol_shape, nb_enc_features, nb_dec_features)\n",
    "\n",
    "# 卷积层\n",
    "disp_tensor = tf.keras.layers.Conv2D(ndims, kernel_size=3, padding='same', name='disp')(unet.output)\n",
    "\n",
    "# SpatialTransformer层\n",
    "moved_image_tensor = vxm.layers.SpatialTransformer(name='spatial_transformer')([unet.inputs[0], disp_tensor])\n",
    "\n",
    "vxm_model = tf.keras.models.Model(unet.inputs, [moved_image_tensor, disp_tensor])\n",
    "\n",
    "# losses. 一个是mse 一个是Grad-l2 ，两个相加，权重为1000:10 （个人行为）\n",
    "losses = ['mse', vxm.losses.Grad('l2').loss]\n",
    "\n",
    "loss_weights = [1000, 10]\n",
    "\n",
    "vxm_model.compile(optimizer='Adam', loss=losses, loss_weights=loss_weights)\n",
    "\n",
    "# 输出看一下model信息\n",
    "vxm_model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据生成器，输入为两张图，一个是moving，一个是fixed。输出初始化为fixed和空\n",
    "def vxm_data_generator(x_data, batch_size=32):\n",
    "    vol_shape = x_data.shape[1:]\n",
    "    ndims = len(vol_shape)\n",
    "    zero_phi = np.zeros([batch_size, *vol_shape, ndims])\n",
    "    \n",
    "    while True:\n",
    "        idx1 = np.random.randint(0, x_data.shape[0], size=batch_size)\n",
    "        moving_images = x_data[idx1, ..., np.newaxis]\n",
    "        idx2 = np.random.randint(0, x_data.shape[0], size=batch_size)\n",
    "        fixed_images = x_data[idx2, ..., np.newaxis]\n",
    "        inputs = [moving_images, fixed_images]\n",
    "        outputs = [fixed_images, zero_phi]\n",
    "        yield inputs, outputs      \n",
    "        \n",
    "train_generator = vxm_data_generator(x_train)\n",
    "# 测试一下数据生成器\n",
    "# 画图\n",
    "input_sample, output_sample = next(train_generator)\n",
    "slices_2d = [f[0,...,0] for f in input_sample + output_sample]\n",
    "titles = ['input_moving', 'input_fixed', 'output_moved_ground_truth', 'zero']\n",
    "vxm.plot.slices(slices_2d, titles=titles, cmaps=['gray'], do_colorbars=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 训练\n",
    "nb_epochs = 10\n",
    "steps_per_epoch = 10\n",
    "hist = vxm_model.fit_generator(train_generator, epochs=nb_epochs, steps_per_epoch=steps_per_epoch, verbose=1)\n",
    "\n",
    "# 保存模型\n",
    "# vxm_model.save_weights('cvpr-mri.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 推理\n",
    "val_generator = vxm_data_generator(x_val, batch_size = 1)\n",
    "val_input, _ = next(val_generator)\n",
    "val_pred = vxm_model.predict(val_input)\n",
    "\n",
    "# 显示结果\n",
    "slices_2d = [f[0,...,0] for f in val_input + val_pred]\n",
    "titles = ['input_moving', 'input_fixed', 'predicted_moved', 'deformation_x']\n",
    "vxm.plot.slices(slices_2d, titles=titles, cmaps=['gray'], do_colorbars=True)\n",
    "flow = val_pred[1].squeeze()[::3,::3]\n",
    "vxm.plot.flow([flow], width=10)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python(tensorflow)",
   "language": "python",
   "name": "tensorflow"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}